{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7e569af4",
   "metadata": {},
   "source": [
    "# Tutorial 06, case 7a: Stokes problem with Dirichlet control\n",
    "\n",
    "In this tutorial we solve the optimal control problem\n",
    "\n",
    "$$\\min J(y, u) = \\frac{1}{2} \\int_{\\Omega_{obs}} |\\text{curl} v|^2 dx + \\frac{\\alpha}{2} \\int_{\\Gamma_C} |\\nabla_{\\mathbf{t}} u|^2 ds$$\n",
    "s.t.\n",
    "$$\\begin{cases}\n",
    "- \\nu \\Delta v + \\nabla p = f                 & \\text{in } \\Omega\\\\\n",
    "                       \\text{div} v = 0       & \\text{in } \\Omega\\\\\n",
    "                                  v = g       & \\text{on } \\Gamma_{in}\\\\\n",
    "                                  v = 0       & \\text{on } \\Gamma_{w}\\\\\n",
    "                 v \\cdot \\mathbf{n} = u       & \\text{on } \\Gamma_{C}\\\\\n",
    "                 v \\cdot \\mathbf{t} = 0       & \\text{on } \\Gamma_{C}\\\\\n",
    "                 v \\cdot \\mathbf{n} = 0       & \\text{on } \\Gamma_{s}\\\\\n",
    "  \\nu \\partial_n v \\cdot \\mathbf{t} = 0       & \\text{on } \\Gamma_{s}\\\\\n",
    "             p n - \\nu \\partial_n v = 0       & \\text{on } \\Gamma_{N}\n",
    "\\end{cases}$$\n",
    "\n",
    "where\n",
    "$$\\begin{align*}\n",
    "& \\Omega                      & \\text{unit square}\\\\\n",
    "& \\Gamma_{in}                 & \\text{has boundary id 1}\\\\\n",
    "& \\Gamma_{s}                  & \\text{has boundary id 2}\\\\\n",
    "& \\Gamma_{N}                  & \\text{has boundary id 3}\\\\\n",
    "& \\Gamma_{C}                  & \\text{has boundary id 4}\\\\\n",
    "& \\Gamma_{w}                  & \\text{has boundary id 5}\\\\\n",
    "& u \\in L^2(\\Gamma_C)         & \\text{control variable}\\\\\n",
    "& v \\in [H^1(\\Omega)]^2       & \\text{state velocity variable}\\\\\n",
    "& p \\in L^2(\\Omega)           & \\text{state pressure variable}\\\\\n",
    "& \\alpha > 0                  & \\text{penalization parameter}\\\\\n",
    "& v_d                         & \\text{desired state}\\\\\n",
    "& f                           & \\text{forcing term}\\\\\n",
    "& g                           & \\text{inlet profile}\\\\\n",
    "\\end{align*}$$\n",
    "using an adjoint formulation solved by a one shot approach.\n",
    "\n",
    "The test case is from section 5 of\n",
    "```\n",
    "F. Negri, A. Manzoni and G. Rozza. Reduced basis approximation of parametrized optimal flow control problems for the Stokes equations. Computer and Mathematics with Applications, 69(4):319-336, 2015.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8f7489",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dolfinx.fem\n",
    "import dolfinx.io\n",
    "import dolfinx.mesh\n",
    "import gmsh\n",
    "import mpi4py.MPI\n",
    "import numpy as np\n",
    "import numpy.typing\n",
    "import petsc4py.PETSc\n",
    "import ufl\n",
    "import viskex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d5489a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiphenicsx.fem\n",
    "import multiphenicsx.fem.petsc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb340273",
   "metadata": {},
   "source": [
    "### Geometrical parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943839f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "L1 = 0.9\n",
    "L2 = 0.35\n",
    "L3 = 0.55\n",
    "L4 = 0.2\n",
    "H = 1.0\n",
    "r = 0.1\n",
    "mesh_size = 0.025"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c1000e7",
   "metadata": {},
   "source": [
    "### Mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4331002",
   "metadata": {},
   "outputs": [],
   "source": [
    "gmsh.initialize()\n",
    "gmsh.model.add(\"mesh\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb4d59ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "p0 = gmsh.model.geo.addPoint(0.0, 0.0, 0.0, mesh_size)\n",
    "p1 = gmsh.model.geo.addPoint(L1, 0.0, 0.0, mesh_size)\n",
    "p2 = gmsh.model.geo.addPoint(L1 + L2, 0.0, 0.0, mesh_size)\n",
    "p3 = gmsh.model.geo.addPoint(L1 + L2 + L3, 0.0, 0.0, mesh_size)\n",
    "p4 = gmsh.model.geo.addPoint(L1 + L2 + L3 + L4, 0.0, 0.0, mesh_size)\n",
    "p5 = gmsh.model.geo.addPoint(L1 + L2 + L3 + L4, H, 0.0, mesh_size)\n",
    "p6 = gmsh.model.geo.addPoint(L1 + L2 + L3, H, 0.0, mesh_size)\n",
    "p7 = gmsh.model.geo.addPoint(L1 + L2, H, 0.0, mesh_size)\n",
    "p8 = gmsh.model.geo.addPoint(L1, H, 0.0, mesh_size)\n",
    "p9 = gmsh.model.geo.addPoint(0.0, H, 0.0, mesh_size)\n",
    "p10 = gmsh.model.geo.addPoint(L1, H / 2, 0.0, mesh_size)\n",
    "p11 = gmsh.model.geo.addPoint(L1, H / 2 + r, 0.0, mesh_size)\n",
    "p12 = gmsh.model.geo.addPoint(L1, H / 2 - r, 0.0, mesh_size)\n",
    "p13 = gmsh.model.geo.addPoint(L1 + L2, H / 2 - r, 0.0, mesh_size)\n",
    "p14 = gmsh.model.geo.addPoint(L1 + L2 + L3, H / 2 - 3 * r, 0.0, mesh_size)\n",
    "p15 = gmsh.model.geo.addPoint(L1 + L2 + L3, H / 2 + 3 * r, 0.0, mesh_size)\n",
    "p16 = gmsh.model.geo.addPoint(L1 + L2, H / 2 + r, 0.0, mesh_size)\n",
    "l0 = gmsh.model.geo.addLine(p0, p1)\n",
    "l1 = gmsh.model.geo.addLine(p1, p2)\n",
    "l2 = gmsh.model.geo.addLine(p2, p3)\n",
    "l3 = gmsh.model.geo.addLine(p3, p4)\n",
    "l4 = gmsh.model.geo.addLine(p4, p5)\n",
    "l5 = gmsh.model.geo.addLine(p5, p6)\n",
    "l6 = gmsh.model.geo.addLine(p6, p7)\n",
    "l7 = gmsh.model.geo.addLine(p7, p8)\n",
    "l8 = gmsh.model.geo.addLine(p8, p9)\n",
    "l9 = gmsh.model.geo.addLine(p9, p0)\n",
    "l10 = gmsh.model.geo.addLine(p12, p13)\n",
    "l11 = gmsh.model.geo.addLine(p13, p14)\n",
    "l12 = gmsh.model.geo.addLine(p14, p15)\n",
    "l13 = gmsh.model.geo.addLine(p15, p16)\n",
    "l14 = gmsh.model.geo.addLine(p16, p11)\n",
    "l15 = gmsh.model.geo.addLine(p13, p16)\n",
    "l16 = gmsh.model.geo.addLine(p1, p12)\n",
    "l17 = gmsh.model.geo.addLine(p11, p8)\n",
    "l18 = gmsh.model.geo.addLine(p2, p13)\n",
    "l19 = gmsh.model.geo.addLine(p16, p7)\n",
    "l20 = gmsh.model.geo.addLine(p3, p14)\n",
    "l21 = gmsh.model.geo.addLine(p15, p6)\n",
    "c0 = gmsh.model.geo.addCircleArc(p11, p10, p12)\n",
    "line_loop_subdomain1 = gmsh.model.geo.addCurveLoop([l0, l16, -c0, l17, l8, l9])\n",
    "line_loop_subdomain2a = gmsh.model.geo.addCurveLoop([l1, l18, -l10, -l16])\n",
    "line_loop_subdomain2b = gmsh.model.geo.addCurveLoop([l7, -l17, -l14, l19])\n",
    "line_loop_subdomain3a = gmsh.model.geo.addCurveLoop([l2, l20, -l11, -l18])\n",
    "line_loop_subdomain3b = gmsh.model.geo.addCurveLoop([l6, -l19, -l13, l21])\n",
    "line_loop_subdomain3c = gmsh.model.geo.addCurveLoop([l3, l4, l5, -l21, -l12, -l20])\n",
    "line_loop_subdomain4 = gmsh.model.geo.addCurveLoop([l11, l12, l13, -l15])\n",
    "subdomain1 = gmsh.model.geo.addPlaneSurface([line_loop_subdomain1])\n",
    "subdomain2a = gmsh.model.geo.addPlaneSurface([line_loop_subdomain2a])\n",
    "subdomain2b = gmsh.model.geo.addPlaneSurface([line_loop_subdomain2b])\n",
    "subdomain3a = gmsh.model.geo.addPlaneSurface([line_loop_subdomain3a])\n",
    "subdomain3b = gmsh.model.geo.addPlaneSurface([line_loop_subdomain3b])\n",
    "subdomain3c = gmsh.model.geo.addPlaneSurface([line_loop_subdomain3c])\n",
    "subdomain4 = gmsh.model.geo.addPlaneSurface([line_loop_subdomain4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fd262ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "gmsh.model.geo.synchronize()\n",
    "gmsh.model.addPhysicalGroup(1, [l9], 1)\n",
    "gmsh.model.addPhysicalGroup(1, [l0, l1, l2, l3, l5, l6, l7, l8], 2)\n",
    "gmsh.model.addPhysicalGroup(1, [l4], 3)\n",
    "gmsh.model.addPhysicalGroup(1, [l10, l14], 4)\n",
    "gmsh.model.addPhysicalGroup(1, [c0, l15], 5)\n",
    "gmsh.model.addPhysicalGroup(2, [subdomain1], 1)\n",
    "gmsh.model.addPhysicalGroup(2, [subdomain2a, subdomain2b], 2)\n",
    "gmsh.model.addPhysicalGroup(2, [subdomain3a, subdomain3b, subdomain3c], 3)\n",
    "gmsh.model.addPhysicalGroup(2, [subdomain4], 4)\n",
    "gmsh.model.mesh.generate(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd40cb3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh, subdomains, boundaries = dolfinx.io.gmshio.model_to_mesh(\n",
    "    gmsh.model, comm=mpi4py.MPI.COMM_WORLD, rank=0, gdim=2)\n",
    "gmsh.finalize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81ad6b69-5ee6-4b2f-8277-3abca703691e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create connectivities required by the rest of the code\n",
    "mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf84bd00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define associated measures\n",
    "dx = ufl.Measure(\"dx\", subdomain_data=subdomains)\n",
    "ds = ufl.Measure(\"ds\", subdomain_data=boundaries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac714914",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normal and tangent\n",
    "n = ufl.FacetNormal(mesh)\n",
    "t = ufl.as_vector([n[1], -n[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f78d0922",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_mesh(mesh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e14527b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_mesh_tags(mesh, subdomains, \"subdomains\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baf7dbd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_mesh_tags(mesh, boundaries, \"boundaries\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d4714d9",
   "metadata": {},
   "source": [
    "### Function spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e2e8ebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_velocity = dolfinx.fem.functionspace(mesh, (\"Lagrange\", 2, (mesh.geometry.dim, )))\n",
    "Y_pressure = dolfinx.fem.functionspace(mesh, (\"Lagrange\", 1))\n",
    "U = dolfinx.fem.functionspace(mesh, (\"Lagrange\", 2))\n",
    "L = U.clone()\n",
    "Q_velocity = Y_velocity.clone()\n",
    "Q_pressure = Y_pressure.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad28816c",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y_velocity_0 = Y_velocity.sub(0)\n",
    "Y_velocity_1 = Y_velocity.sub(1)\n",
    "Q_velocity_0 = Q_velocity.sub(0)\n",
    "Q_velocity_1 = Q_velocity.sub(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c035c21d",
   "metadata": {},
   "source": [
    "### Restrictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec4434c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dofs_Y_velocity = np.arange(0, Y_velocity.dofmap.index_map.size_local + Y_velocity.dofmap.index_map.num_ghosts)\n",
    "dofs_Y_pressure = np.arange(0, Y_pressure.dofmap.index_map.size_local + Y_pressure.dofmap.index_map.num_ghosts)\n",
    "dofs_U = dolfinx.fem.locate_dofs_topological(U, boundaries.dim, boundaries.indices[boundaries.values == 4])\n",
    "dofs_L = dofs_U\n",
    "dofs_Q_velocity = dofs_Y_velocity\n",
    "dofs_Q_pressure = dofs_Y_pressure\n",
    "restriction_Y_velocity = multiphenicsx.fem.DofMapRestriction(Y_velocity.dofmap, dofs_Y_velocity)\n",
    "restriction_Y_pressure = multiphenicsx.fem.DofMapRestriction(Y_pressure.dofmap, dofs_Y_pressure)\n",
    "restriction_U = multiphenicsx.fem.DofMapRestriction(U.dofmap, dofs_U)\n",
    "restriction_L = multiphenicsx.fem.DofMapRestriction(L.dofmap, dofs_L)\n",
    "restriction_Q_velocity = multiphenicsx.fem.DofMapRestriction(Q_velocity.dofmap, dofs_Q_velocity)\n",
    "restriction_Q_pressure = multiphenicsx.fem.DofMapRestriction(Q_pressure.dofmap, dofs_Q_pressure)\n",
    "restriction = [\n",
    "    restriction_Y_velocity, restriction_Y_pressure, restriction_U, restriction_L,\n",
    "    restriction_Q_velocity, restriction_Q_pressure]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbbc7931",
   "metadata": {},
   "source": [
    "### Trial and test functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ff08560",
   "metadata": {},
   "outputs": [],
   "source": [
    "(v, p) = (ufl.TrialFunction(Y_velocity), ufl.TrialFunction(Y_pressure))\n",
    "(w, q) = (ufl.TestFunction(Y_velocity), ufl.TestFunction(Y_pressure))\n",
    "(u, l) = (ufl.TrialFunction(U), ufl.TrialFunction(L))\n",
    "(r, m) = (ufl.TestFunction(U), ufl.TestFunction(L))\n",
    "(z, b) = (ufl.TrialFunction(Q_velocity), ufl.TrialFunction(Q_pressure))\n",
    "(s, d) = (ufl.TestFunction(Q_velocity), ufl.TestFunction(Q_pressure))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38fdfb3a",
   "metadata": {},
   "source": [
    " ### Problem data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb50cf1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def non_zero_eval(x: np.typing.NDArray[np.float64]) -> np.typing.NDArray[  # type: ignore[no-any-unimported]\n",
    "        petsc4py.PETSc.ScalarType]:\n",
    "    \"\"\"Return the flat velocity profile at the inlet.\"\"\"\n",
    "    values = np.zeros((2, x.shape[1]))\n",
    "    values[0, :] = 2.5\n",
    "    return values\n",
    "\n",
    "\n",
    "nu = 1.\n",
    "alpha = 1.e-2\n",
    "ff = dolfinx.fem.Constant(mesh, tuple(petsc4py.PETSc.ScalarType(0) for _ in range(2)))\n",
    "bc0 = dolfinx.fem.Function(Y_velocity)\n",
    "bc0_component = dolfinx.fem.Function(Y_velocity_0.collapse()[0])\n",
    "bc1 = dolfinx.fem.Function(Y_velocity)\n",
    "bc1.interpolate(non_zero_eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d6012b0",
   "metadata": {},
   "source": [
    "### Optimality conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "907c589d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vorticity(v: ufl.Argument, w: ufl.Argument) -> ufl.core.expr.Expr:  # type: ignore[no-any-unimported]\n",
    "    \"\"\"Return the UFL expression corresponding to the inner(curl, curl) operator.\"\"\"\n",
    "    return ufl.inner(ufl.curl(v), ufl.curl(w))\n",
    "\n",
    "\n",
    "def penalty(u: ufl.Argument, r: ufl.Argument) -> ufl.core.expr.Expr:  # type: ignore[no-any-unimported]\n",
    "    \"\"\"Return the UFL expression corresponding to the penalty term.\"\"\"\n",
    "    return alpha * ufl.inner(ufl.dot(ufl.grad(u), t), ufl.dot(ufl.grad(r), t))\n",
    "\n",
    "\n",
    "a = [[vorticity(v, w) * dx(4), None, None, ufl.inner(l, ufl.dot(w, n)) * ds(4),\n",
    "      nu * ufl.inner(ufl.grad(z), ufl.grad(w)) * dx, - ufl.inner(b, ufl.div(w)) * dx],\n",
    "     [None, None, None, None, - ufl.inner(ufl.div(z), q) * dx, None],\n",
    "     [None, None, penalty(u, r) * ds(4), - ufl.inner(l, r) * ds(4), None, None],\n",
    "     [ufl.inner(ufl.dot(v, n), m) * ds(4), None, - ufl.inner(u, m) * ds(4), None, None, None],\n",
    "     [nu * ufl.inner(ufl.grad(v), ufl.grad(s)) * dx, - ufl.inner(p, ufl.div(s)) * dx, None, None, None, None],\n",
    "     [- ufl.inner(ufl.div(v), d) * dx, None, None, None, None, None]]\n",
    "f = [None,\n",
    "     None,\n",
    "     None,\n",
    "     None,\n",
    "     ufl.inner(ff, s) * dx,\n",
    "     None]\n",
    "a[0][0] += dolfinx.fem.Constant(mesh, petsc4py.PETSc.ScalarType(0)) * ufl.inner(v, w) * dx\n",
    "a[4][4] = dolfinx.fem.Constant(mesh, petsc4py.PETSc.ScalarType(0)) * ufl.inner(z, s) * dx\n",
    "f[0] = ufl.inner(dolfinx.fem.Constant(mesh, tuple(petsc4py.PETSc.ScalarType(0) for _ in range(2))), w) * dx\n",
    "f[1] = ufl.inner(dolfinx.fem.Constant(mesh, petsc4py.PETSc.ScalarType(0)), q) * dx\n",
    "f[2] = ufl.inner(dolfinx.fem.Constant(mesh, petsc4py.PETSc.ScalarType(0)), r) * dx\n",
    "f[3] = ufl.inner(dolfinx.fem.Constant(mesh, petsc4py.PETSc.ScalarType(0)), m) * dx\n",
    "f[5] = ufl.inner(dolfinx.fem.Constant(mesh, petsc4py.PETSc.ScalarType(0)), d) * dx\n",
    "a_cpp = dolfinx.fem.form(a)\n",
    "f_cpp = dolfinx.fem.form(f)\n",
    "\n",
    "\n",
    "def bdofs(\n",
    "    space_from: dolfinx.fem.FunctionSpace, space_to: dolfinx.fem.FunctionSpace, idx: int\n",
    ") -> np.typing.NDArray[np.int32]:\n",
    "    \"\"\"Locate DOFs on the boundary `idx`.\"\"\"\n",
    "    return dolfinx.fem.locate_dofs_topological(\n",
    "        (space_from, space_to), mesh.topology.dim - 1, boundaries.indices[boundaries.values == idx])\n",
    "\n",
    "\n",
    "bc = [\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc1, bdofs(Y_velocity, bc1.function_space, 1), Y_velocity),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0_component, bdofs(Y_velocity_1, bc0_component.function_space, 2), Y_velocity_1),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0_component, bdofs(Y_velocity_0, bc0_component.function_space, 4), Y_velocity_0),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0, bdofs(Y_velocity, bc0.function_space, 5), Y_velocity),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0, bdofs(Q_velocity, bc0.function_space, 1), Q_velocity),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0_component, bdofs(Q_velocity_1, bc0_component.function_space, 2), Q_velocity_1),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0, bdofs(Q_velocity, bc0.function_space, 4), Q_velocity),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0, bdofs(Q_velocity, bc0.function_space, 5), Q_velocity)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5515e6fe",
   "metadata": {},
   "source": [
    "### Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe04119",
   "metadata": {},
   "outputs": [],
   "source": [
    "(v, p) = (dolfinx.fem.Function(Y_velocity), dolfinx.fem.Function(Y_pressure))\n",
    "(u, l) = (dolfinx.fem.Function(U), dolfinx.fem.Function(L))\n",
    "(z, b) = (dolfinx.fem.Function(Q_velocity), dolfinx.fem.Function(Q_pressure))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a830152f",
   "metadata": {},
   "source": [
    "### Cost functional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f5c592",
   "metadata": {},
   "outputs": [],
   "source": [
    "J = 0.5 * vorticity(v, v) * dx(4) + 0.5 * penalty(u, u) * ds(4)\n",
    "J_cpp = dolfinx.fem.form(J)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d896949",
   "metadata": {},
   "source": [
    "### Uncontrolled functional value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "099e6491",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract state forms from the optimality conditions\n",
    "a_state = [[ufl.replace(a[i][j], {s: w, d: q}) if a[i][j] is not None else None for j in (0, 1)] for i in (4, 5)]\n",
    "f_state = [ufl.replace(f[i], {s: w, d: q}) for i in (4, 5)]\n",
    "a_state_cpp = dolfinx.fem.form(a_state)\n",
    "f_state_cpp = dolfinx.fem.form(f_state)\n",
    "bc_state = [\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc1, bdofs(Y_velocity, bc1.function_space, 1), Y_velocity),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0_component, bdofs(Y_velocity_1, bc0_component.function_space, 2), Y_velocity_1),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0, bdofs(Y_velocity, bc0.function_space, 4), Y_velocity),\n",
    "    dolfinx.fem.dirichletbc(\n",
    "        bc0, bdofs(Y_velocity, bc0.function_space, 5), Y_velocity)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa065d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the block linear system for the state\n",
    "A_state = multiphenicsx.fem.petsc.assemble_matrix_block(\n",
    "    a_state_cpp, bcs=bc_state, restriction=([restriction[i] for i in (4, 5)], [restriction[j] for j in (0, 1)]))\n",
    "A_state.assemble()\n",
    "F_state = multiphenicsx.fem.petsc.assemble_vector_block(\n",
    "    f_state_cpp, a_state_cpp, bcs=bc_state, restriction=[restriction[i] for i in (4, 5)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d109e851",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve\n",
    "vp = multiphenicsx.fem.petsc.create_vector_block(\n",
    "    [f_cpp[j] for j in (0, 1)], restriction=[restriction[j] for j in (0, 1)])\n",
    "ksp = petsc4py.PETSc.KSP()\n",
    "ksp.create(mesh.comm)\n",
    "ksp.setOperators(A_state)\n",
    "ksp.setType(\"preonly\")\n",
    "ksp.getPC().setType(\"lu\")\n",
    "ksp.getPC().setFactorSolverType(\"mumps\")\n",
    "ksp.setFromOptions()\n",
    "ksp.solve(F_state, vp)\n",
    "vp.ghostUpdate(addv=petsc4py.PETSc.InsertMode.INSERT, mode=petsc4py.PETSc.ScatterMode.FORWARD)\n",
    "ksp.destroy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e8a3b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the block solution in components\n",
    "with multiphenicsx.fem.petsc.BlockVecSubVectorWrapper(\n",
    "        vp, [c.function_space.dofmap for c in (v, p)], [restriction[j] for j in (0, 1)]) as vp_wrapper:\n",
    "    for vp_wrapper_local, component in zip(vp_wrapper, (v, p)):\n",
    "        with component.x.petsc_vec.localForm() as component_local:\n",
    "            component_local[:] = vp_wrapper_local\n",
    "vp.destroy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd6ee9de",
   "metadata": {},
   "outputs": [],
   "source": [
    "J_uncontrolled = mesh.comm.allreduce(dolfinx.fem.assemble_scalar(J_cpp), op=mpi4py.MPI.SUM)\n",
    "print(\"Uncontrolled J =\", J_uncontrolled)\n",
    "assert np.isclose(J_uncontrolled, 2.9236194)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "788cc135",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_vector_field(v, \"uncontrolled state velocity\", glyph_factor=3e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0dee94b",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(p, \"uncontrolled state pressure\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "928b6fca",
   "metadata": {},
   "source": [
    "### Optimal control"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9505f239",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the block linear system for the optimality conditions\n",
    "A = multiphenicsx.fem.petsc.assemble_matrix_block(a_cpp, bcs=bc, restriction=(restriction, restriction))\n",
    "A.assemble()\n",
    "F = multiphenicsx.fem.petsc.assemble_vector_block(f_cpp, a_cpp, bcs=bc, restriction=restriction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e2ac886",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve\n",
    "vpulzb = multiphenicsx.fem.petsc.create_vector_block(f_cpp, restriction=restriction)\n",
    "ksp = petsc4py.PETSc.KSP()\n",
    "ksp.create(mesh.comm)\n",
    "ksp.setOperators(A)\n",
    "ksp.setType(\"preonly\")\n",
    "ksp.getPC().setType(\"lu\")\n",
    "ksp.getPC().setFactorSolverType(\"mumps\")\n",
    "ksp.setFromOptions()\n",
    "ksp.solve(F, vpulzb)\n",
    "vpulzb.ghostUpdate(addv=petsc4py.PETSc.InsertMode.INSERT, mode=petsc4py.PETSc.ScatterMode.FORWARD)\n",
    "ksp.destroy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5defd35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the block solution in components\n",
    "with multiphenicsx.fem.petsc.BlockVecSubVectorWrapper(\n",
    "        vpulzb, [c.function_space.dofmap for c in (v, p, u, l, z, b)], restriction) as vpulzb_wrapper:\n",
    "    for vpulzb_wrapper_local, component in zip(vpulzb_wrapper, (v, p, u, l, z, b)):\n",
    "        with component.x.petsc_vec.localForm() as component_local:\n",
    "            component_local[:] = vpulzb_wrapper_local\n",
    "vpulzb.destroy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f34bb270",
   "metadata": {},
   "outputs": [],
   "source": [
    "J_controlled = mesh.comm.allreduce(dolfinx.fem.assemble_scalar(J_cpp), op=mpi4py.MPI.SUM)\n",
    "print(\"Optimal J =\", J_controlled)\n",
    "assert np.isclose(J_controlled, 1.71027397)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d8ff315",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_vector_field(v, \"state velocity\", glyph_factor=3e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51253a6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(p, \"state pressure\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6239f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(u, \"control\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d2f43eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(l, \"lambda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af0e2056",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_vector_field(z, \"adjoint velocity\", glyph_factor=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0194dcca",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(b, \"adjoint pressure\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython"
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
